{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "56161d6b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running as a Jupyter notebook - intended for development only!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_35643/572068249.py:21: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
      "  ipython.magic(\"load_ext autoreload\")\n",
      "/tmp/ipykernel_35643/572068249.py:22: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
      "  ipython.magic(\"autoreload 2\")\n"
     ]
    }
   ],
   "source": [
    "# Environment bootstrap: Colab installs its dependencies, a local Jupyter/VSCode kernel enables autoreload\n",
    "DEVELOPMENT_MODE = False\n",
    "try:\n",
    "    import google.colab\n",
    "    IN_COLAB = True\n",
    "    print(\"Running as a Colab notebook\")\n",
    "    %pip install git+https://github.com/TransformerLensOrg/TransformerLens.git\n",
    "    %pip install circuitsvis\n",
    "    %pip install torchtyping\n",
    "    \n",
    "    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n",
    "    # # Install another version of node that makes PySvelte work way faster\n",
    "    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n",
    "    # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n",
    "except ImportError:\n",
    "    # `google.colab` could not be imported, so treat this as a local kernel\n",
    "    IN_COLAB = False\n",
    "    print(\"Running as a Jupyter notebook - intended for development only!\")\n",
    "    from IPython import get_ipython\n",
    "\n",
    "    ipython = get_ipython()\n",
    "    # Automatically reload the HookedTransformer code as it's edited, without restarting the kernel.\n",
    "    # run_line_magic(name, args) replaces the deprecated ipython.magic(...) call.\n",
    "    ipython.run_line_magic(\"load_ext\", \"autoreload\")\n",
    "    ipython.run_line_magic(\"autoreload\", \"2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "da9f5a40",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2023-10-23 17:45:03,881] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    }
   ],
   "source": [
    "# Import stuff\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import numpy as np\n",
    "import einops\n",
    "from fancy_einsum import einsum\n",
    "import tqdm.auto as tqdm\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "from pathlib import Path\n",
    "import plotly.express as px\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from torchtyping import TensorType as TT\n",
    "from typing import List, Union, Optional\n",
    "from jaxtyping import Float, Int\n",
    "from functools import partial\n",
    "import copy\n",
    "\n",
    "import itertools\n",
    "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n",
    "import dataclasses\n",
    "import datasets\n",
    "from IPython.display import HTML\n",
    "# import circuitsvis as cv\n",
    "\n",
    "import transformer_lens\n",
    "import transformer_lens.utils as utils\n",
    "from transformer_lens.hook_points import (\n",
    "    HookedRootModule,\n",
    "    HookPoint,\n",
    ")  # Hooking utilities\n",
    "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache\n",
    "\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n",
    "    \"\"\"Show `tensor` as a heatmap on an \"RdBu\" colour scale with its midpoint at 0.0.\n",
    "\n",
    "    `xaxis` / `yaxis` are used as the axis labels, extra keyword arguments are\n",
    "    forwarded to `px.imshow`, and `renderer` is passed to the figure's `show`.\n",
    "    \"\"\"\n",
    "    axis_labels = {\"x\": xaxis, \"y\": yaxis}\n",
    "    figure = px.imshow(\n",
    "        utils.to_numpy(tensor),\n",
    "        color_continuous_midpoint=0.0,\n",
    "        color_continuous_scale=\"RdBu\",\n",
    "        labels=axis_labels,\n",
    "        **kwargs,\n",
    "    )\n",
    "    figure.show(renderer)\n",
    "\n",
    "def line(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n",
    "    \"\"\"Draw `tensor` as a line plot.\n",
    "\n",
    "    `xaxis` / `yaxis` are used as the axis labels, extra keyword arguments are\n",
    "    forwarded to `px.line`, and `renderer` is passed to the figure's `show`.\n",
    "    \"\"\"\n",
    "    axis_labels = {\"x\": xaxis, \"y\": yaxis}\n",
    "    figure = px.line(utils.to_numpy(tensor), labels=axis_labels, **kwargs)\n",
    "    figure.show(renderer)\n",
    "\n",
    "def scatter(x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", renderer=None, **kwargs):\n",
    "    \"\"\"Draw a scatter plot of `y` against `x`.\n",
    "\n",
    "    Both inputs are converted with `utils.to_numpy` first. `xaxis`, `yaxis` and\n",
    "    `caxis` label the x, y and colour axes, extra keyword arguments are forwarded\n",
    "    to `px.scatter`, and `renderer` is passed to the figure's `show`.\n",
    "    \"\"\"\n",
    "    x_values, y_values = utils.to_numpy(x), utils.to_numpy(y)\n",
    "    axis_labels = {\"x\": xaxis, \"y\": yaxis, \"color\": caxis}\n",
    "    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)\n",
    "    figure.show(renderer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "29c87845",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the reference Hugging Face tokenizer and causal-LM model\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "hf_checkpoint = \"bigscience/bloom-560m\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(hf_checkpoint)\n",
    "model = AutoModelForCausalLM.from_pretrained(hf_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1f7ac1e1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Missing key for a weight matrix in pretrained, filled in with an empty tensor: pos_embed.W_pos\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model bloom-560m into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "# Turn off LayerNorm folding, value-bias folding and writing-weight centering so that\n",
    "# intermediate values in between transformer blocks can be compared\n",
    "bloom = HookedTransformer.from_pretrained(\n",
    "    \"bloom-560m\",\n",
    "    fold_ln=False,\n",
    "    fold_value_biases=False,\n",
    "    center_writing_weights=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1490c1cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "avg logits difference: -1.5987565348041244e-05\n",
      "max logits difference: 0.0006103515625\n"
     ]
    }
   ],
   "source": [
    "text = '''\n",
    "TransformerLens lets you load in 50+ different open source language models,\n",
    "and exposes the internal activations of the model to you. You can cache\n",
    "any internal activation in the model, and add in functions to edit, remove\n",
    "or replace these activations as the model runs.\n",
    "'''\n",
    "input_ids = tokenizer(text, return_tensors='pt')['input_ids']\n",
    "\n",
    "# Ground-truth logits from the HF model vs. logits from the HookedTransformer\n",
    "gt_logits = model(input_ids)['logits']\n",
    "my_logits = bloom(input_ids)\n",
    "\n",
    "# The HF logits are mean-centred over their last axis before the comparison\n",
    "centered_gt_logits = gt_logits - gt_logits.mean(-1, keepdim=True)\n",
    "logits_gap = my_logits.cpu() - centered_gt_logits\n",
    "mean_diff = logits_gap.mean()\n",
    "max_diff = logits_gap.abs().max()\n",
    "print(\"avg logits difference:\", mean_diff.item())\n",
    "print(\"max logits difference:\", max_diff.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "48806097",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "***** Matching hf and T-Lens residual stream in between transformer blocks *****\n",
      "layer 2 \t not close, max difference: 1.52587890625e-05\n",
      "layer 3 \t not close, max difference: 1.8358230590820312e-05\n",
      "layer 19 \t not close, max difference: 3.0517578125e-05\n",
      "layer 20 \t not close, max difference: 3.0517578125e-05\n",
      "layer 21 \t not close, max difference: 3.0517578125e-05\n",
      "layer 22 \t not close, max difference: 6.103515625e-05\n",
      "layer 23 \t not close, max difference: 6.866455078125e-05\n",
      "***** \ttesting with atol=0.001 and rtol=0.001\t *****\n",
      "All layers match with atol=0.001 rtol=0.001\n"
     ]
    }
   ],
   "source": [
    "gt_cache = model(input_ids, output_hidden_states=True)['hidden_states']\n",
    "_, my_cache = bloom.run_with_cache(input_ids)\n",
    "\n",
    "\n",
    "def report_resid_mismatches(**tolerances):\n",
    "    \"\"\"Compare every `resid_pre` activation in `my_cache` with the matching entry of `gt_cache`.\n",
    "\n",
    "    Keyword arguments (e.g. `atol`, `rtol`) are forwarded to `torch.testing.assert_close`;\n",
    "    with none given, its default tolerances apply. Prints one line per layer that is\n",
    "    not close and returns True only when every layer matches.\n",
    "    \"\"\"\n",
    "    all_close = True\n",
    "    # Use the model's own layer count rather than a hard-coded 24\n",
    "    for i in range(bloom.cfg.n_layers):\n",
    "        my_resid = my_cache['resid_pre', i]\n",
    "        # Compare on whatever device the T-Lens activation lives on instead of assuming CUDA\n",
    "        gt_resid = gt_cache[i].to(my_resid.device)\n",
    "        try:\n",
    "            torch.testing.assert_close(my_resid, gt_resid, **tolerances)\n",
    "        except AssertionError:\n",
    "            # Only a failed closeness check is expected here; any other error should surface\n",
    "            max_diff = (my_resid - gt_resid).abs().max()\n",
    "            print(f\"layer {i} \\t not close, max difference: {max_diff}\")\n",
    "            all_close = False\n",
    "    return all_close\n",
    "\n",
    "\n",
    "print(\"*\"*5, \"Matching hf and T-Lens residual stream in between transformer blocks\", \"*\"*5)\n",
    "# First pass with default tolerances; fall back to a looser bound if any layer differs\n",
    "use_loose_bound = not report_resid_mismatches()\n",
    "pass_loose_bound = True\n",
    "\n",
    "if use_loose_bound:\n",
    "    atol = rtol = 1e-3\n",
    "    print(\"*\"*5, f\"\\ttesting with atol={atol} and rtol={rtol}\\t\",\"*\"*5)\n",
    "    pass_loose_bound = report_resid_mismatches(atol=atol, rtol=rtol)\n",
    "\n",
    "    if pass_loose_bound:\n",
    "        print(f\"All layers match with atol={atol} rtol={rtol}\")\n",
    "    else:\n",
    "        # Previously this case ended without any summary line\n",
    "        print(f\"Some layers still differ with atol={atol} rtol={rtol}\")\n",
    "else: \n",
    "    print(\"All layers match\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "07bc8247",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "T-Lens next token loss: 3.9514622688293457\n",
      "HF next token loss: 3.951453685760498\n",
      "diff in loss (abs): 8.58306884765625e-06\n"
     ]
    }
   ],
   "source": [
    "# Next-token loss reported by each implementation for the same input\n",
    "my_loss = bloom(input_ids, return_type='loss')\n",
    "gt_outputs = model(input_ids, labels=input_ids)\n",
    "gt_loss = gt_outputs.loss\n",
    "loss_gap = (gt_loss - my_loss).abs()\n",
    "\n",
    "print(\"T-Lens next token loss:\", my_loss.item())\n",
    "print(\"HF next token loss:\", gt_loss.item())\n",
    "print(\"diff in loss (abs):\", loss_gap.item())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
